#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import atexit
import os
import re
import shlex
import shutil
import sys
from subprocess import run
from contextlib import suppress
from glob import glob
from tempfile import NamedTemporaryFile
from typing import NamedTuple, List, Iterable, Set, Optional

if __name__ != "__main__":
    # The rest of this file changes the working directory and runs the docs build as
    # module-level statements, so it must never be imported.
    raise Exception(
        "This file is intended to be executed as an executable program. You cannot use it as a module. "
        "To run this script, run the ./build command"
    )


class DocBuildError(NamedTuple):
    """A single problem found while checking or building the documentation."""

    # Path of the file the problem refers to; None when it is not known.
    file_path: Optional[str]
    # Line number inside file_path; None when it is not known.
    line_no: Optional[int]
    # Human-readable description of the problem.
    message: str

    def __lt__(self, other) -> bool:
        """Order errors by (file_path, line_no, message), placing unknown locations first.

        Plain tuple comparison raises ``TypeError`` as soon as ``None`` has to be compared with
        a ``str`` or an ``int``, which made ``sorted()`` fail on a mix of located and
        unlocated errors. Missing values are therefore treated as ``""`` and ``0``.
        """
        if not isinstance(other, DocBuildError):
            # Keep the regular tuple semantics for anything that is not a DocBuildError.
            return tuple.__lt__(self, other)
        left = (self.file_path or "", self.line_no or 0, self.message)
        right = (other.file_path or "", other.line_no or 0, other.message)
        return left < right


# Collects every problem reported by the checks and by the Sphinx build defined in this file.
build_errors: List[DocBuildError] = []

# Switch to the directory containing this script, so that the relative paths used in this file
# (e.g. "_build" and "_api") do not depend on where the script was started from.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Parent of the current working directory, and the "airflow" folder inside that parent.
ROOT_PROJECT_DIR = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
ROOT_PACKAGE_DIR = os.path.join(ROOT_PROJECT_DIR, "airflow")


def clean_files() -> None:
    """Remove everything inside the ``_build`` and ``_api`` folders, keeping the folders.

    Both folders are looked up relative to the current working directory. A folder that does
    not exist, or an entry that disappears while cleaning, is silently skipped.
    """
    print("Removing content of the _build and _api folders")
    for folder in ("_build", "_api"):
        # glob() returns paths that already start with the folder name, so each entry is used
        # as-is; prefixing it with the folder again would point at a path that does not exist.
        for entry in glob(f"{folder}/*"):
            with suppress(FileNotFoundError):
                if os.path.isdir(entry) and not os.path.islink(entry):
                    shutil.rmtree(entry)
                else:
                    # rmtree() refuses plain files and symlinks, so remove those directly.
                    os.remove(entry)
    print("Removed content of the _build and _api folders")


def prepare_directories() -> None:
    """Make sure the ``_build`` and ``_api`` folders exist in the current working directory.

    Inside a container (detected by the presence of ``/.dockerenv``) the folders are created
    with ``sudo`` and an exit handler is registered that hands their ownership back to the host
    user, read from the ``HOST_USER_ID`` and ``HOST_GROUP_ID`` environment variables.
    Outside a container the folders are created with the current user's permissions.

    :raises subprocess.CalledProcessError: when one of the ``mkdir`` commands fails
    """
    if os.path.exists("/.dockerenv"):
        # This script can be run both - in container and outside of it.
        # Here we are inside the container which means that we should (when the host is Linux)
        # fix permissions of the _build and _api folders via sudo.
        # Those files are mounted from the host via docs folder and we might not have permissions to
        # write to those directories (and remove the _api folder).
        # We know we have sudo capabilities inside the container.
        print("Creating the _build and _api folders in case they do not exist")
        run(["sudo", "mkdir", "-pv", "_build"], check=True)
        run(["sudo", "mkdir", "-pv", "_api"], check=True)
        print("Created the _build and _api folders in case they do not exist")

        def restore_ownership() -> None:
            # We are inside the container which means that we should fix back the permissions of the
            # _build and _api folder files, so that they can be accessed by the host user
            # The _api folder should be deleted by then but just in case we should change the ownership
            host_user_id = os.environ["HOST_USER_ID"]
            host_group_id = os.environ["HOST_GROUP_ID"]
            print(f"Changing ownership of docs/_build folder back to {host_user_id}:{host_group_id}")
            run(["sudo", "chown", "-R", f'{host_user_id}:{host_group_id}', "_build"], check=True)
            if os.path.exists("_api"):
                run(["sudo", "chown", "-R", f'{host_user_id}:{host_group_id}', "_api"], check=True)

            print(f"Changed ownership of docs/_build folder back to {host_user_id}:{host_group_id}")

        # Runs at interpreter exit, i.e. after the build has finished writing its files.
        atexit.register(restore_ownership)

    else:
        # We are outside the container so we simply make sure that the directories exist
        print("Creating the _build and _api folders in case they do not exist")
        run(["mkdir", "-pv", "_build"], check=True)
        run(["mkdir", "-pv", "_api"], check=True)
        print("Created the _build and _api folders in case they do not exist")


def display_errors_summary() -> None:
    """Print every collected build error in sorted order, followed by a closing separator.

    For each error the message is printed first. When the file path is known it is printed as
    well, and when the line number is known too, a code snippet around that line follows.
    """
    banner = "=" * 20
    ordered_errors = sorted(build_errors)
    for position, build_error in enumerate(ordered_errors, start=1):
        print(banner, f"Error {position:3}", banner)
        print(build_error.message)
        print()

        path, line = build_error.file_path, build_error.line_no
        if not path:
            # Nothing to point at when the location is unknown.
            continue
        if line:
            print(f"File path: {path} ({line})")
            print()
            print(prepare_code_snippet(path, line))
        else:
            print(f"File path: {path}")

    print("=" * 50)


def check_file_not_contains(file_path: str, pattern: str, message: str) -> None:
    """Record a build error for every line of ``file_path`` that matches ``pattern``.

    :param file_path: path of the file to scan
    :param pattern: regular expression searched for in each decoded line
    :param message: text stored in the ``DocBuildError`` appended to ``build_errors``
    """
    with open(file_path, "rb", 0) as scanned_file:
        matcher = re.compile(pattern)

        for line_number, raw_line in enumerate(scanned_file, start=1):
            if matcher.search(raw_line.decode()) is None:
                continue
            build_errors.append(
                DocBuildError(file_path=file_path, line_no=line_number, message=message)
            )


def filter_file_list_by_pattern(file_paths: Iterable[str], pattern: str) -> List[str]:
    """Return the paths, in their original order, of the files whose content matches ``pattern``.

    :param file_paths: paths of the files to inspect
    :param pattern: regular expression searched for in the decoded content of each file
    :return: the subset of ``file_paths`` with at least one match
    """
    matcher = re.compile(pattern)

    def contains_pattern(path: str) -> bool:
        # Read the whole file at once so the pattern may span several lines.
        with open(path, "rb", 0) as handle:
            content = handle.read().decode()
        return matcher.search(content) is not None

    return [path for path in file_paths if contains_pattern(path)]


def find_modules(deprecated_only: bool = False) -> Set[str]:
    """Return the dotted names of the Python modules found under ``ROOT_PACKAGE_DIR``.

    ``__init__.py`` files are skipped. Names are derived from the file path relative to
    ``ROOT_PROJECT_DIR``.

    :param deprecated_only: when true, keep only the files whose content matches
        ``This module is deprecated.``
    :return: set of module names
    """
    candidates: Iterable[str] = [
        path
        for path in glob(f"{ROOT_PACKAGE_DIR}/**/*.py", recursive=True)
        if not path.endswith("__init__.py")
    ]
    if deprecated_only:
        candidates = filter_file_list_by_pattern(candidates, r"This module is deprecated.")

    module_names = set()
    for path in candidates:
        relative_path = os.path.relpath(path, ROOT_PROJECT_DIR)
        # Drop the extension, then turn the path separators into dots.
        without_extension, _, _ = relative_path.rpartition(".")
        module_names.add(without_extension.replace("/", "."))
    return module_names


def assert_file_not_contains(file_path: str, pattern: str, message: str) -> None:
    """Record a build error for every line of ``file_path`` that matches ``pattern``.

    Kept as a separate name for existing callers; the work is done by
    ``check_file_not_contains``, which has exactly the same contract.

    :param file_path: path of the file to scan
    :param pattern: regular expression searched for in each decoded line
    :param message: text stored in the ``DocBuildError`` appended to ``build_errors``
    """
    check_file_not_contains(file_path=file_path, pattern=pattern, message=message)


def check_exampleinclude_for_example_dags() -> None:
    """Report every ``literalinclude`` directive that references ``example_dags``.

    All ``.rst`` files below the current working directory are scanned; findings are added to
    ``build_errors`` by ``check_file_not_contains``.
    """
    # "*.rst" rather than "*rst": only match files that really carry the .rst extension.
    all_docs_files = glob("**/*.rst", recursive=True)

    for doc_file in all_docs_files:
        check_file_not_contains(
            file_path=doc_file,
            pattern=r"literalinclude::.+example_dags",
            message=(
                "literalinclude directive is prohibited for example DAGs. \n"
                "You should use an exampleinclude directive to include example DAGs."
            )
        )


def check_enforce_code_block() -> None:
    """Report every line that starts with the ``.. code::`` directive.

    All ``.rst`` files below the current working directory are scanned; findings are added to
    ``build_errors`` by ``assert_file_not_contains``.
    """
    # "*.rst" rather than "*rst": only match files that really carry the .rst extension.
    all_docs_files = glob("**/*.rst", recursive=True)

    for doc_file in all_docs_files:
        assert_file_not_contains(
            file_path=doc_file,
            # The dots are escaped so that only a literal ".." starts a match; unescaped they
            # would accept any two characters in front of " code::".
            pattern=r"^\.\. code::",
            message=(
                "We recommend using the code-block directive instead of the code directive. "
                "The code-block directive is more feature-full."
            )
        )


def prepare_code_snippet(file_path: str, line_no: int, context_lines_count: int = 5) -> str:
    """Return the lines around ``line_no`` of ``file_path``, each prefixed with its line number.

    The snippet spans ``context_lines_count`` lines before and after the requested line and is
    clipped at the beginning and the end of the file. When pygments can be imported, the content
    is run through its terminal formatter first; otherwise the plain text is used.

    :param file_path: path of the file to read
    :param line_no: 1-based number of the line the snippet is centred on
    :param context_lines_count: number of lines shown on each side of ``line_no``
    :return: the numbered lines joined with newlines
    """
    def guess_lexer_for_filename(filename):
        from pygments.util import ClassNotFound
        from pygments.lexers import get_lexer_for_filename

        try:
            lexer = get_lexer_for_filename(filename)
        except ClassNotFound:
            # No lexer is registered for this file name; fall back to plain text.
            from pygments.lexers.special import TextLexer
            lexer = TextLexer()
        return lexer

    with open(file_path) as text_file:
        code = text_file.read()

    # Highlight code; without pygments the text is left as it is.
    with suppress(ImportError):
        import pygments

        from pygments.formatters.terminal import TerminalFormatter
        code = pygments.highlight(
            code=code, formatter=TerminalFormatter(), lexer=guess_lexer_for_filename(file_path)
        )

    # Prepend line number
    numbered_lines = [f"{number:4} | {line}" for number, line in enumerate(code.split("\n"), 1)]
    # Cut out the snippet. Line N is stored at index N - 1, hence the extra "- 1" that keeps the
    # context before the requested line as long as the context after it.
    start_index = max(0, line_no - context_lines_count - 1)
    end_index = line_no + context_lines_count
    return "\n".join(numbered_lines[start_index:end_index])


def parse_sphinx_warnings(warning_text: str) -> List[DocBuildError]:
    """Convert the text of a Sphinx warning file into a list of ``DocBuildError`` objects.

    Every non-empty line becomes one error. A line of the form ``<path>:<line>:<message>`` with
    a numeric ``<line>`` is split into its parts; any other line is kept whole as the message,
    with no file path and no line number.

    :param warning_text: content of the warning file, one warning per line
    :return: the parsed errors, in the order of the input lines
    """
    sphinx_build_errors = []
    for sphinx_warning in warning_text.split("\n"):
        if not sphinx_warning:
            continue
        # Used whenever the line cannot be split into location and message.
        raw_error = DocBuildError(file_path=None, line_no=None, message=sphinx_warning)

        warning_parts = sphinx_warning.split(":", 2)
        if len(warning_parts) != 3:
            sphinx_build_errors.append(raw_error)
            continue

        file_path, line_text, message = warning_parts
        try:
            # The only step that can fail: the second field is not always a number.
            line_no = int(line_text)
        except ValueError:
            sphinx_build_errors.append(raw_error)
        else:
            sphinx_build_errors.append(
                DocBuildError(file_path=file_path, line_no=line_no, message=message)
            )
    return sphinx_build_errors


def build_sphinx_docs() -> None:
    """Run ``sphinx-build`` for the HTML docs and collect its problems in ``build_errors``.

    Sphinx is asked to write its warnings to a temporary file. A non-zero exit status is
    recorded as one error, and every warning read back from that file is recorded as well.
    """
    with NamedTemporaryFile() as warnings_file:
        sphinx_cmd = [
            "sphinx-build",
            "-b",
            "html",
            "-d",
            "_build/doctrees",
            "--color",
            "-w",
            warnings_file.name,
            ".",
            "_build/html",
        ]
        printable_cmd = " ".join(shlex.quote(argument) for argument in sphinx_cmd)
        print("Executing cmd: ", printable_cmd)

        exit_status = run(sphinx_cmd).returncode
        if exit_status != 0:
            failure = DocBuildError(
                file_path=None,
                line_no=None,
                message=f"Sphinx returned non-zero exit status: {exit_status}.",
            )
            build_errors.append(failure)

        # Sphinx wrote to the file through its name; rewind our own handle before reading.
        warnings_file.seek(0)
        raw_warnings = warnings_file.read().decode()
        # Remove 7-bit C1 ANSI escape sequences
        plain_warnings = re.sub(r"\x1B[@-_][0-?]*[ -/]*[@-~]", "", raw_warnings)
        build_errors.extend(parse_sphinx_warnings(plain_warnings))


print("Current working directory: ", os.getcwd())

# Make sure the output folders exist, then clean their content.
prepare_directories()
clean_files()

# Static checks of the .rst sources; their findings are collected in build_errors.
check_enforce_code_block()
check_exampleinclude_for_example_dags()

# Do not start the Sphinx build when the static checks already reported problems.
if build_errors:
    display_errors_summary()
    print()
    print("The documentation has errors. Fix them to build documentation.")
    sys.exit(1)

build_sphinx_docs()

# At this point build_errors can only contain what build_sphinx_docs() added:
# a non-zero Sphinx exit status and the warnings Sphinx wrote.
if build_errors:
    display_errors_summary()
    print()
    print("The documentation has errors.")
    sys.exit(1)
